Download repo

In [36]:
import os
import shutil
import zipfile
import urllib.request
In [37]:
# Download the LinearizedNNs repository as a zip archive and unpack it
# into the working directory.
REPO_ZIP_FILE = 'LinearizedNNs-master.zip'
# Skip the slow network fetch when the archive is already present from a
# previous run; delete the file manually to force a fresh download.
if not os.path.exists(REPO_ZIP_FILE):
    urllib.request.urlretrieve('https://github.com/maxkvant/LinearizedNNs/archive/master.zip', REPO_ZIP_FILE)

REPO_PATH = "LinearizedNNs-master"
# Remove any stale checkout so the extraction below starts clean.
if os.path.exists(REPO_PATH):
    shutil.rmtree(REPO_PATH)

with zipfile.ZipFile(REPO_ZIP_FILE, 'r') as zip_ref:
    zip_ref.extractall('.')

# Fail fast if extraction did not produce the expected directory.
assert os.path.exists(REPO_PATH)
/py-env/platform-env/lib/python3.7/site-packages/ml_platform/user_messages.py:13: UserWarning: 
The following variables cannot be serialized: zip_ref

Please note that these variables can be lost in the next working session
  warnings.warn(message)

Imports

In [38]:
# Make the downloaded repository's src/ directory importable so the
# pytorch_impl package below resolves.
import sys

sys.path.append(REPO_PATH + "/src")
In [58]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torchvision.datasets import FashionMNIST

from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.linear_model import RidgeClassifier
from sklearn.decomposition import PCA

from xgboost import XGBClassifier

from pytorch_impl.nns import ResNet, FCN, CNN
from pytorch_impl.nns import warm_up_batch_norm
from pytorch_impl.estimators import LinearizedSgdEstimator, SgdEstimator, MatrixExpEstimator
from pytorch_impl import ClassifierTraining
from pytorch_impl.matrix_exp import matrix_exp, compute_exp_term
from pytorch_impl.nns.utils import to_one_hot

FashionMNIST

In [51]:
# Reproducibility: fix torch's RNG before any weight initialization.
torch.manual_seed(0)

# Prefer the first CUDA device when available; fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

print('Torch version: {}'.format(torch.__version__))
print('Device: {}'.format(device))

D = 28            # FashionMNIST images are D x D pixels
num_classes = 10  # ten clothing categories

train_loader = torch.utils.data.DataLoader(
    FashionMNIST(root='.', train=True, download=True,
                 transform=transforms.ToTensor()),
    batch_size=4096, shuffle=True, pin_memory=True)

# download=True makes this cell self-contained (the original relied on the
# training dataset having been fetched first); shuffle=False because
# evaluation gains nothing from a randomized batch order and a fixed order
# keeps runs deterministic.
test_loader = torch.utils.data.DataLoader(
    FashionMNIST(root='.', train=False, download=True,
                 transform=transforms.ToTensor()),
    batch_size=4096, shuffle=False, pin_memory=True)
Torch version: 1.4.0
Device: cuda:0
In [58]:
# Sanity check of matrix_exp: exponentiating a 30-degree rotation
# generator should yield the 2x2 rotation matrix
# [[cos a, -sin a], [sin a, cos a]] (see the Out[] below).
angle = 30 / 180 * np.pi
generator = torch.tensor([[0, -1], [1, 0]]) * angle
matrix_exp(generator, device)
Out[58]:
tensor([[ 0.8660, -0.5000],
        [ 0.5000,  0.8660]], device='cuda:0')
In [59]:
# Same rotation computed through the series form exp(M) = I + M * f(M),
# where compute_exp_term supplies f(M); the result should match the
# previous cell's matrix_exp output.
a = 30 / 180 * np.pi
M = torch.tensor([[0, -1], [1, 0]]) * a
identity = torch.eye(2).to(device)
M_clone = M.clone().to(device)
identity + M_clone @ compute_exp_term(M, device)
Out[59]:
tensor([[ 0.8660, -0.5000],
        [ 0.5000,  0.8660]], device='cuda:0')
In [60]:
# Instantiate a CNN just to inspect its layer structure via the repr
# below; the first positional argument sets the classifier's
# out_features (see the printed Linear layer), not used for training here.
CNN(1, input_channels=1)
Out[60]:
CNN(
  (layers): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
  (classifier): Linear(in_features=64, out_features=1, bias=False)
)
In [61]:
# Peek at a single batch to confirm the loader's tensor shape
# (expected: [4096, 1, 28, 28]).  next(iter(...)) is the idiomatic
# single-batch fetch; the original next(enumerate(...)) allocated a
# needless counter.
X, y = next(iter(train_loader))
X.size()
Out[61]:
torch.Size([4096, 1, 28, 28])
In [62]:
# Build a wider (256-channel) CNN and run warm_up_batch_norm over the
# training data — presumably to populate the BatchNorm running
# statistics before the linearized fit below (TODO confirm in
# pytorch_impl.nns.warm_up_batch_norm).
model = CNN(1, input_channels=1, num_channels=256).to(device)
warm_up_batch_norm(model, train_loader, device)
BatchNorm2d
BatchNorm2d
BatchNorm2d
BatchNorm2d
Out[62]:
CNN(
  (layers): Sequential(
    (0): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
  )
  (classifier): Linear(in_features=256, out_features=1, bias=False)
)
In [63]:
# Fit a MatrixExpEstimator on one 4096-sample batch and report accuracy
# on the full test set.
estimator = MatrixExpEstimator(model, num_classes, device, learning_rate=1e1, momentum=0.)
# next(iter(...)) replaces the original next(enumerate(...)) anti-idiom.
X, y = next(iter(train_loader))
X, y = X.to(device), y.to(device)
estimator.fit(X, y)
ClassifierTraining(estimator, device).get_accuracy(test_loader)
accuracy 0.10229, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
Out[63]:
0.8320434920024127
In [64]:
def get_estimator(model, num_classes):
    """Build a fresh MatrixExpEstimator (no momentum) for a single batch."""
    return MatrixExpEstimator(model, num_classes, device, learning_rate=10.)

# Average the per-batch solutions: fit an independent estimator on each
# full-size training batch, then average their weight tensors `ws`.
# NOTE(review): ws_sum is seeded from a freshly constructed estimator's
# ws — this is only an unbiased average if the initial ws are all zeros;
# otherwise the division by `batches` below carries one extra
# initial-weight term.  TODO confirm against MatrixExpEstimator's
# constructor.  Also note `+=` mutates that detached tensor in place.
ws_sum = get_estimator(model, num_classes).ws.detach()

prev_size = None
batches = 0
for batch_id, (X, y) in enumerate(train_loader):
    # Stop at the first smaller (final, partial) batch so every averaged
    # estimator was fitted on the same number of samples.
    if (prev_size is not None) and prev_size != len(X):
        break
    prev_size = len(X)
    X, y = X.to(device), y.to(device)
        
    estimator = get_estimator(model, num_classes)
    estimator.fit(X, y)
    
    ws_sum += estimator.ws.detach()
    batches += 1
    
# Load the averaged weights into a fresh estimator; the next cell
# evaluates it on the test set.
estimator = get_estimator(model, num_classes)
estimator.ws = ws_sum / batches
accuracy 0.09937, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10181, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.09863, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10059, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.09766, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.09644, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10400, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10010, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10425, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.09692, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.09888, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.09473, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10376, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
accuracy 0.10645, loss 5.00000
computing grads ... 86s
exponentiating kernel matrix ... 86s
In [65]:
ClassifierTraining(estimator, device).get_accuracy(test_loader)
Out[65]:
0.8562371797549228
In [80]:
# Baseline for comparison: ordinary SGD training of a 10-output CNN with
# cross-entropy loss for 200 epochs (test accuracy printed at the end).
model = CNN(10, input_channels=1).to(device)
warm_up_batch_norm(model, train_loader, device)

learning_rate = .005

estimator = SgdEstimator(model, nn.CrossEntropyLoss(), learning_rate)
training  = ClassifierTraining(estimator, device)

training.train(train_loader, test_loader, num_epochs=200, learning_rate=learning_rate)
BatchNorm2d
BatchNorm2d
BatchNorm2d
BatchNorm2d
epoch 0/200, 0s since start
epoch 1/200, 8s since start
epoch 2/200, 17s since start
epoch 3/200, 25s since start
epoch 4/200, 33s since start
epoch 5/200, 41s since start
epoch 6/200, 50s since start
epoch 7/200, 58s since start
epoch 8/200, 66s since start
epoch 9/200, 75s since start
epoch 10/200, 83s since start
epoch 11/200, 91s since start
epoch 12/200, 99s since start
epoch 13/200, 108s since start
epoch 14/200, 116s since start
epoch 15/200, 124s since start
epoch 16/200, 133s since start
epoch 17/200, 141s since start
epoch 18/200, 149s since start
epoch 19/200, 157s since start
epoch 20/200, 166s since start
epoch 21/200, 174s since start
epoch 22/200, 182s since start
epoch 23/200, 190s since start
epoch 24/200, 198s since start
epoch 25/200, 207s since start
epoch 26/200, 215s since start
epoch 27/200, 223s since start
epoch 28/200, 232s since start
epoch 29/200, 240s since start
epoch 30/200, 248s since start
epoch 31/200, 257s since start
epoch 32/200, 265s since start
epoch 33/200, 273s since start
epoch 34/200, 281s since start
epoch 35/200, 290s since start
epoch 36/200, 298s since start
epoch 37/200, 306s since start
epoch 38/200, 315s since start
epoch 39/200, 323s since start
epoch 40/200, 331s since start
epoch 41/200, 340s since start
epoch 42/200, 348s since start
epoch 43/200, 356s since start
epoch 44/200, 364s since start
epoch 45/200, 373s since start
epoch 46/200, 381s since start
epoch 47/200, 389s since start
epoch 48/200, 398s since start
epoch 49/200, 406s since start
epoch 50/200, 414s since start
epoch 51/200, 423s since start
epoch 52/200, 431s since start
epoch 53/200, 439s since start
epoch 54/200, 447s since start
epoch 55/200, 456s since start
epoch 56/200, 464s since start
epoch 57/200, 472s since start
epoch 58/200, 481s since start
epoch 59/200, 489s since start
epoch 60/200, 497s since start
epoch 61/200, 505s since start
epoch 62/200, 514s since start
epoch 63/200, 522s since start
epoch 64/200, 530s since start
epoch 65/200, 539s since start
epoch 66/200, 547s since start
epoch 67/200, 555s since start
epoch 68/200, 563s since start
epoch 69/200, 572s since start
epoch 70/200, 580s since start
epoch 71/200, 588s since start
epoch 72/200, 597s since start
epoch 73/200, 605s since start
epoch 74/200, 614s since start
epoch 75/200, 622s since start
epoch 76/200, 630s since start
epoch 77/200, 639s since start
epoch 78/200, 647s since start
epoch 79/200, 655s since start
epoch 80/200, 663s since start
epoch 81/200, 672s since start
epoch 82/200, 680s since start
epoch 83/200, 688s since start
epoch 84/200, 697s since start
epoch 85/200, 705s since start
epoch 86/200, 714s since start
epoch 87/200, 722s since start
epoch 88/200, 730s since start
epoch 89/200, 738s since start
epoch 90/200, 747s since start
epoch 91/200, 755s since start
epoch 92/200, 763s since start
epoch 93/200, 772s since start
epoch 94/200, 780s since start
epoch 95/200, 788s since start
epoch 96/200, 796s since start
epoch 97/200, 805s since start
epoch 98/200, 813s since start
epoch 99/200, 821s since start
epoch 100/200, 830s since start
epoch 101/200, 838s since start
epoch 102/200, 846s since start
epoch 103/200, 855s since start
epoch 104/200, 863s since start
epoch 105/200, 871s since start
epoch 106/200, 880s since start
epoch 107/200, 888s since start
epoch 108/200, 896s since start
epoch 109/200, 905s since start
epoch 110/200, 913s since start
epoch 111/200, 921s since start
epoch 112/200, 929s since start
epoch 113/200, 938s since start
epoch 114/200, 946s since start
epoch 115/200, 955s since start
epoch 116/200, 963s since start
epoch 117/200, 971s since start
epoch 118/200, 979s since start
epoch 119/200, 987s since start
epoch 120/200, 996s since start
epoch 121/200, 1004s since start
epoch 122/200, 1013s since start
epoch 123/200, 1021s since start
epoch 124/200, 1029s since start
epoch 125/200, 1037s since start
epoch 126/200, 1046s since start
epoch 127/200, 1054s since start
epoch 128/200, 1062s since start
epoch 129/200, 1071s since start
epoch 130/200, 1079s since start
epoch 131/200, 1087s since start
epoch 132/200, 1096s since start
epoch 133/200, 1104s since start
epoch 134/200, 1112s since start
epoch 135/200, 1121s since start
epoch 136/200, 1129s since start
epoch 137/200, 1137s since start
epoch 138/200, 1146s since start
epoch 139/200, 1154s since start
epoch 140/200, 1162s since start
epoch 141/200, 1171s since start
epoch 142/200, 1179s since start
epoch 143/200, 1188s since start
epoch 144/200, 1196s since start
epoch 145/200, 1204s since start
epoch 146/200, 1213s since start
epoch 147/200, 1221s since start
epoch 148/200, 1229s since start
epoch 149/200, 1238s since start
epoch 150/200, 1246s since start
epoch 151/200, 1254s since start
epoch 152/200, 1263s since start
epoch 153/200, 1271s since start
epoch 154/200, 1279s since start
epoch 155/200, 1288s since start
epoch 156/200, 1296s since start
epoch 157/200, 1305s since start
epoch 158/200, 1313s since start
epoch 159/200, 1321s since start
epoch 160/200, 1330s since start
epoch 161/200, 1338s since start
epoch 162/200, 1346s since start
epoch 163/200, 1355s since start
epoch 164/200, 1363s since start
epoch 165/200, 1371s since start
epoch 166/200, 1379s since start
epoch 167/200, 1388s since start
epoch 168/200, 1396s since start
epoch 169/200, 1404s since start
epoch 170/200, 1412s since start
epoch 171/200, 1421s since start
epoch 172/200, 1429s since start
epoch 173/200, 1437s since start
epoch 174/200, 1446s since start
epoch 175/200, 1454s since start
epoch 176/200, 1462s since start
epoch 177/200, 1471s since start
epoch 178/200, 1479s since start
epoch 179/200, 1487s since start
epoch 180/200, 1496s since start
epoch 181/200, 1504s since start
epoch 182/200, 1512s since start
epoch 183/200, 1520s since start
epoch 184/200, 1529s since start
epoch 185/200, 1537s since start
epoch 186/200, 1546s since start
epoch 187/200, 1554s since start
epoch 188/200, 1562s since start
epoch 189/200, 1570s since start
epoch 190/200, 1579s since start
epoch 191/200, 1587s since start
epoch 192/200, 1595s since start
epoch 193/200, 1604s since start
epoch 194/200, 1612s since start
epoch 195/200, 1620s since start
epoch 196/200, 1629s since start
epoch 197/200, 1637s since start
epoch 198/200, 1645s since start
epoch 199/200, 1653s since start
training took 1662s
test_accuracy 0.849
In [12]:
# Shared learning rate for the linearized and plain FCN runs below.
learning_rate = 0.02
In [13]:
# Train the linearized fully-connected net for 10 epochs — presumably a
# first-order (NTK-style) approximation around the initial weights, per
# LinearizedSgdEstimator's name; TODO confirm in pytorch_impl.estimators.
linearized_estimator = LinearizedSgdEstimator(FCN(1, D * D).to(device), num_classes, nn.CrossEntropyLoss(), learning_rate)
linearized_training  = ClassifierTraining(linearized_estimator, device)

linearized_training.train(train_loader, test_loader, num_epochs=10, learning_rate=learning_rate)
epoch 0/10, 0s since start
epoch 1/10, 81s since start
epoch 2/10, 162s since start
epoch 3/10, 243s since start
epoch 4/10, 324s since start
epoch 5/10, 405s since start
epoch 6/10, 486s since start
epoch 7/10, 566s since start
epoch 8/10, 647s since start
epoch 9/10, 728s since start
training took 809s
test_accuracy 0.798
In [14]:
# Plain (non-linearized) SGD on the same FCN architecture for comparison.
# NOTE(review): here the FCN gets 10 outputs directly, while the
# linearized cell above used FCN(1, ...) plus a separate num_classes
# argument — verify this asymmetry is intentional.
estimator = SgdEstimator(FCN(10, D * D).to(device), nn.CrossEntropyLoss(), learning_rate)
training  = ClassifierTraining(estimator, device)

training.train(train_loader, test_loader, num_epochs=10, learning_rate=learning_rate)
epoch 0/10, 0s since start
epoch 1/10, 8s since start
epoch 2/10, 15s since start
epoch 3/10, 23s since start
epoch 4/10, 30s since start
epoch 5/10, 38s since start
epoch 6/10, 45s since start
epoch 7/10, 52s since start
epoch 8/10, 60s since start
epoch 9/10, 67s since start
training took 75s
test_accuracy 0.825
In [ ]:
# Fit the current estimator on one training batch, then inspect the fit:
# the prediction tensor's shape and its mean squared error against the
# one-hot encoded labels.  next(iter(...)) replaces the original
# next(enumerate(...)) anti-idiom.
X, y = next(iter(train_loader))
X, y = X.to(device), y.to(device)


estimator.fit(X, y)

print(estimator.predict(X).size())

((estimator.predict(X) - to_one_hot(y, num_classes).to(device)) ** 2).mean()
In [ ]:
# Display the estimator's repr as a quick sanity check.
estimator
In [ ]:
# Accuracy of the fitted estimator on a single test batch (fraction of
# argmax predictions matching the labels).  next(iter(...)) replaces the
# original next(enumerate(...)) anti-idiom.
X, y = next(iter(test_loader))
X, y = X.to(device), y.to(device)

(torch.argmax(estimator.predict(X), dim=1) == y).double().mean()

CIFAR-10

In [70]:
# Reproducibility: fix torch's RNG before building anything stochastic.
torch.manual_seed(0)

# Prefer the first CUDA device when available; fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Per-channel (RGB) normalization statistics for CIFAR-10.
cifar10_stats = {
    "mean" : (0.4914, 0.4822, 0.4465),
    "std"  : (0.24705882352941178, 0.24352941176470588, 0.2615686274509804),
}

# Standard CIFAR-10 training augmentation: reflect-pad 4 pixels on each
# side (done via numpy in the three Lambda steps, since RandomCrop here
# is given no padding of its own), random 32x32 crop, random horizontal
# flip, then tensor conversion and normalization.
transform_train = transforms.Compose([
    transforms.Lambda(lambda x: np.asarray(x)),
    transforms.Lambda(lambda x: np.pad(x, [(4, 4), (4, 4), (0, 0)], mode='reflect')),
    transforms.Lambda(lambda x: Image.fromarray(x)),
    transforms.RandomCrop(32),
    
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
])

# No augmentation at test time: only tensor conversion and normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_stats['mean'], cifar10_stats['std']),
])

train_loader = torch.utils.data.DataLoader(
                  datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train),
               batch_size=2048, shuffle=True, pin_memory=True)

test_loader  = torch.utils.data.DataLoader(
                  datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test),
               batch_size=2048, shuffle=True, pin_memory=True)

# Display the chosen device.
device
Files already downloaded and verified
Files already downloaded and verified
Out[70]:
device(type='cuda', index=0)
In [55]:
#### TODO